import os
import re
from collections import defaultdict
from datetime import datetime

def get_model_type(filename):
    """Infer the model label for a result file from its name.

    Keywords are matched case-insensitively and in priority order:
    'cpm' -> 'CPM', then 'qwen' -> 'Qwen', then 'pro' -> 'Pro'. The first
    keyword found wins. A name containing none of them is labelled 'Gemini'.
    """
    lowered = filename.lower()
    for keyword, label in (('cpm', 'CPM'), ('qwen', 'Qwen'), ('pro', 'Pro')):
        if keyword in lowered:
            return label
    return 'Gemini'

def parse_result_file(file_path):
    """Parse a single maze-run result file into a dict of metrics.

    The file is read as UTF-8 text and scanned for labelled fields of the form
    ``<label>: <value>``. Either an ASCII ``:`` or a full-width ``：`` may
    separate label and value. Only fields that are found appear in the result:

    - ``difficulty`` (str) from ``难度级别``
    - ``steps`` (int) from ``总步数``
    - ``reward`` (float, may be negative) from ``总奖励``
    - ``invalid_actions`` (int) from ``无效动作次数``
    - ``success`` (bool, True only when the value is ``是``) from ``成功到达目标``
    - ``save_time`` (str of digits) from ``保存时间``

    Args:
        file_path: path of the result file.

    Returns:
        A dict holding whichever of the keys above were present. It is empty
        if the file could not be read/decoded or contained none of the labels.
    """
    data = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, UnicodeDecodeError) as e:
        # Best effort: report the unreadable file; the caller skips empty results.
        print(f"解析文件 {file_path} 时出错: {e}")
        return data

    # Label/value separator: ASCII or full-width colon, then optional spaces.
    sep = r'[:：]\s*'
    # Signed decimal with optional fraction/exponent. The former [\d.]+ could
    # not match a leading '-', so negative rewards were silently dropped.
    number = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'

    difficulty_match = re.search(r'难度级别' + sep + r'(\w+)', content)
    steps_match = re.search(r'总步数' + sep + r'(\d+)', content)
    reward_match = re.search(r'总奖励' + sep + '(' + number + ')', content)
    invalid_match = re.search(r'无效动作次数' + sep + r'(\d+)', content)
    success_match = re.search(r'成功到达目标' + sep + r'(\S+)', content)
    time_match = re.search(r'保存时间' + sep + r'(\d+)', content)

    # Every captured group is already validated by its pattern, so the
    # int()/float() conversions below cannot raise.
    if difficulty_match:
        data['difficulty'] = difficulty_match.group(1)
    if steps_match:
        data['steps'] = int(steps_match.group(1))
    if reward_match:
        data['reward'] = float(reward_match.group(1))
    if invalid_match:
        data['invalid_actions'] = int(invalid_match.group(1))
    if success_match:
        data['success'] = success_match.group(1) == '是'
    if time_match:
        data['save_time'] = time_match.group(1)

    return data

def print_stats_for_group(data_list, group_name, file_handle=None):
    """Print (and optionally write) summary statistics for one group of runs.

    Args:
        data_list: parsed result dicts belonging to the group. Nothing is
            emitted when it is empty.
        group_name: heading shown above the group's statistics.
        file_handle: optional writable text object. When truthy, it receives
            exactly the same text that is printed.

    Emits the run count and success rate. Step, reward and invalid-action
    statistics are emitted only when at least one record carries that field.
    """
    if not data_list:
        return

    total = len(data_list)
    wins = len([entry for entry in data_list if entry.get('success', False)])

    # A leading empty entry reproduces the blank line that opens each group.
    lines = [
        "",
        f"{group_name}:",
        f"  游戏次数: {total}",
        f"  成功率: {wins / total * 100:.1f}%",
    ]

    step_values = [entry['steps'] for entry in data_list if 'steps' in entry]
    if step_values:
        lines.append(f"  平均步数: {sum(step_values)/len(step_values):.1f}")
        lines.append(f"  最少步数: {min(step_values)}")
        lines.append(f"  最多步数: {max(step_values)}")
        # Trimmed mean drops one minimum and one maximum; needs at least 3 samples.
        if len(step_values) > 2:
            middle = sorted(step_values)[1:-1]
            lines.append(f"  去掉最值后平均步数: {sum(middle)/len(middle):.1f}")
        elif len(step_values) == 2:
            lines.append("  去掉最值后平均步数: 无法计算(样本太少)")

    reward_values = [entry['reward'] for entry in data_list if 'reward' in entry]
    if reward_values:
        lines.append(f"  平均奖励: {sum(reward_values)/len(reward_values):.2f}")
        lines.append(f"  最高奖励: {max(reward_values):.2f}")
        lines.append(f"  最低奖励: {min(reward_values):.2f}")

    invalid_values = [entry['invalid_actions'] for entry in data_list if 'invalid_actions' in entry]
    if invalid_values:
        lines.append(f"  平均无效动作: {sum(invalid_values)/len(invalid_values):.1f}")

    # Every line, including the last, is newline-terminated.
    report = "\n".join(lines) + "\n"
    print(report, end='')
    if file_handle:
        file_handle.write(report)

def analyze_results(results_dir="", output_dir="/"):
    """Aggregate all result files in a directory into a summary report.

    Every regular file in *results_dir* is parsed with ``parse_result_file``.
    Files yielding no data are skipped. Each record is tagged with its model
    (``get_model_type`` on the file name) and its file name.

    Statistics are printed to the console and written to a timestamped report
    file, in four sections:
    - overall totals
    - per model
    - per difficulty
    - per model x difficulty

    Args:
        results_dir: directory containing the result files. The default ``""``
            never exists, so the function only prints an error.
            NOTE(review): looks like a redacted placeholder; pass the real
            directory.
        output_dir: directory in which ``maze_analysis_report_<timestamp>.txt``
            is created. Defaults to ``"/"``, the previously hard-coded
            location.
    """
    # Display order for the labels get_model_type() produces.
    model_order = ['CPM', 'Qwen', 'Pro', 'Gemini']

    def ordered_models(present):
        # Known models first in display order, then any other labels
        # alphabetically, so a newly added model is never silently omitted.
        known = [m for m in model_order if m in present]
        return known + sorted(m for m in present if m not in model_order)

    # isdir (not exists): a plain file here would make os.listdir raise.
    if not os.path.isdir(results_dir):
        print(f"目录不存在: {results_dir}")
        return

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(output_dir, f"maze_analysis_report_{timestamp}.txt")

    all_data = []
    model_stats = defaultdict(list)
    difficulty_stats = defaultdict(list)
    model_difficulty_stats = defaultdict(lambda: defaultdict(list))

    # os.listdir order is arbitrary; sorting makes group order (which follows
    # first appearance) reproducible between runs.
    for filename in sorted(os.listdir(results_dir)):
        file_path = os.path.join(results_dir, filename)
        if not os.path.isfile(file_path):
            continue
        data = parse_result_file(file_path)
        if not data:
            continue

        model_type = get_model_type(filename)
        data['model'] = model_type
        data['filename'] = filename

        all_data.append(data)
        model_stats[model_type].append(data)
        if 'difficulty' in data:
            difficulty_stats[data['difficulty']].append(data)
            model_difficulty_stats[model_type][data['difficulty']].append(data)

    if not all_data:
        print("未找到有效的统计文件")
        return

    with open(output_file, 'w', encoding='utf-8') as f:
        def emit(text):
            # Console output gets print()'s trailing newline; the file does not.
            print(text)
            f.write(text)

        header = "=" * 60 + "\n迷宫导航统计汇总报告\n" + "=" * 60 + "\n"
        summary = f"总文件数: {len(all_data)}\n"
        success_count = sum(1 for data in all_data if data.get('success', False))
        success_summary = f"总成功率: {success_count}/{len(all_data)} ({success_count/len(all_data)*100:.1f}%)\n"
        emit(header + summary + success_summary)

        # Per-model statistics
        emit("\n" + "=" * 40 + "\n按模型类型统计\n" + "=" * 40)
        for model_type in ordered_models(model_stats):
            print_stats_for_group(model_stats[model_type], f"{model_type} 模型", f)

        # Per-difficulty statistics
        emit("\n" + "=" * 40 + "\n按难度级别统计\n" + "=" * 40)
        for difficulty, data_list in difficulty_stats.items():
            print_stats_for_group(data_list, f"{difficulty.upper()} 难度", f)

        # Model x difficulty cross statistics
        emit("\n" + "=" * 40 + "\n按模型和难度交叉统计\n" + "=" * 40)
        for model_type in ordered_models(model_difficulty_stats):
            emit(f"\n{model_type} 模型详细统计:")
            for difficulty, data_list in model_difficulty_stats[model_type].items():
                print_stats_for_group(data_list, f"  {difficulty.upper()} 难度", f)

    print(f"\n统计报告已保存到: {output_file}")

if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    analyze_results()
